// Package obix 实现WebStation N4的客户端
package obix

import (
	"bytes"
	"context"
	"crypto/sha256"
	"crypto/tls"
	"encoding/xml"
	"errors"
	"fmt"
	"io"
	"io/ioutil"
	"net/http"
	"net/http/cookiejar"
	"net/url"
	"strings"
	"sync"
	"time"

	"gitee.com/wqt/obix/types"
	"github.com/lib/pq/scram"
	"golang.org/x/net/publicsuffix"
)

// Cookie names.
const (
	// cookieKeyJSESSIONID is the "JSESSIONID" cookie name (conventionally a
	// Java servlet session identifier).
	cookieKeyJSESSIONID = "JSESSIONID"
	// cookieKeyNiagaraUserid is the cookie whose value carries the login user
	// name.
	cookieKeyNiagaraUserid = "niagara_userid"
)

// Logger receives the client's debug output. Any type with a Printf method
// of this shape (for example *log.Logger) satisfies it.
type Logger interface {
	Printf(fmtStr string, args ...interface{})
}

// asyncReq bundles an HTTP request with channels for delivering its result.
// NOTE(review): not referenced anywhere in the visible part of this file;
// presumably a leftover from an asynchronous request queue — confirm before removing.
type asyncReq struct {
	req  *http.Request // request to send
	data chan []byte   // presumably receives the response body — TODO confirm
	err  chan error    // presumably receives the request error — TODO confirm
}

// Client is an obix-protocol client. It uses point URIs to read point values
// or the state of related devices.
type Client struct {
	// addr is the server address; endpoint paths such as "/login" are
	// appended to it.
	addr string
	// username used to log in.
	username string
	// password used to log in.
	password string
	// c is the single http.Client shared by every request this Client makes.
	c *http.Client

	// l, when non-nil, receives debug output about each request.
	// It can be set with the WithLogger option.
	l Logger
	// skipRedirect disables automatic redirect following; a redirect to the
	// login page is how the client detects that it has to log in.
	skipRedirect bool

	// Optional settings:
	// timeout is the HTTP request timeout.
	timeout time.Duration
	// skipVerifyTLS disables TLS certificate verification.
	skipVerifyTLS bool

	// lock is taken for reading by doReqWithCheckLogin around its request and
	// for writing by login for the whole handshake.
	lock sync.RWMutex
}

// ClientOption configures a Client. Options are applied by NewClient after
// the defaults are set and before the underlying HTTP client is built.
type ClientOption func(*Client)

// WithTimeout sets the timeout applied to every HTTP request (it becomes
// http.Client.Timeout). NewClient uses 30 seconds when this option is absent.
func WithTimeout(d time.Duration) ClientOption {
	return func(cl *Client) { cl.timeout = d }
}

// WithLogger installs l as the debug logger. When set, request URLs, status
// codes, redirect locations and cookie names and values are printed through
// it; the cookie values may include session identifiers, so keep such logs
// private.
func WithLogger(l Logger) ClientOption {
	return func(cl *Client) { cl.l = l }
}

// WithSkipVerifyTLS disables TLS certificate verification when v is true.
// Only use it for servers whose (for example self-signed) certificates you
// trust by other means.
func WithSkipVerifyTLS(v bool) ClientOption {
	return func(cl *Client) { cl.skipVerifyTLS = v }
}

// mergeOptions applies opts to c in order, skipping nil entries.
func (c *Client) mergeOptions(opts ...ClientOption) {
	for i := range opts {
		if apply := opts[i]; apply != nil {
			apply(c)
		}
	}
}

// Sentinel errors reported by the client.
var (
	// ErrUsernameIsEmpty is returned by NewClient when no username is given.
	ErrUsernameIsEmpty = errors.New("username is empty")
	// ErrPasswordIsEmpty is returned by NewClient when no password is given.
	ErrPasswordIsEmpty = errors.New("password is empty")
	// ErrNoData reports that no response data was received.
	ErrNoData = errors.New("no data received")
	// ErrSkipRedirect is returned by the redirect policy to stop the HTTP
	// client from following redirects automatically.
	ErrSkipRedirect = errors.New("skip redirect")
)

// checkRedirect is installed as http.Client.CheckRedirect. It allows the
// initial request (via is empty) and refuses every redirect by returning
// ErrSkipRedirect, so callers see the redirect response itself.
func checkRedirect(req *http.Request, via []*http.Request) error {
	if len(via) == 0 {
		return nil
	}
	return ErrSkipRedirect
}

// NewClient creates an obix client for a WebStation N4 server.
//
// addr is the server base address. It must contain a scheme and a host
// (e.g. "https://192.168.1.10"). Trailing slashes are removed because endpoint
// paths such as "/login" are appended to it. username and password are
// required; ErrUsernameIsEmpty / ErrPasswordIsEmpty are returned otherwise.
// opts are applied in order after the defaults (30s timeout, redirects not
// followed). ctx is currently unused; no network request is made here.
func NewClient(ctx context.Context, addr, username, password string, opts ...ClientOption) (*Client, error) {
	u, err := url.Parse(addr)
	if err != nil {
		return nil, err
	}
	// Every request URL is built as addr + path, so an address without a
	// scheme or host could never produce a usable request; reject it early.
	if u.Scheme == "" || u.Host == "" {
		return nil, fmt.Errorf("invalid server address %q: scheme and host are required", addr)
	}
	if username == "" {
		return nil, ErrUsernameIsEmpty
	}
	if password == "" {
		return nil, ErrPasswordIsEmpty
	}
	// Avoid "http://host//login" when addr ends with a slash.
	addr = strings.TrimRight(addr, "/")

	c := &Client{
		addr:         addr,
		username:     username,
		password:     password,
		timeout:      30 * time.Second,
		skipRedirect: true,
	}
	c.mergeOptions(opts...)
	c.initHTTPClient()
	return c, nil
}

// initHTTPClient builds the single http.Client used for all requests.
//
// The client:
//   - keeps cookies in a jar, so cookies set by the server are sent on later
//     requests;
//   - applies c.timeout;
//   - installs checkRedirect when redirects must not be followed;
//   - disables TLS certificate verification when c.skipVerifyTLS is set.
func (c *Client) initHTTPClient() {
	// cookiejar.New does not currently return a non-nil error, so ignoring it is safe.
	jar, _ := cookiejar.New(&cookiejar.Options{
		PublicSuffixList: publicsuffix.List,
	})
	httpClient := &http.Client{
		Jar:     jar,
		Timeout: c.timeout,
	}
	if c.skipRedirect {
		httpClient.CheckRedirect = checkRedirect
	}
	if c.skipVerifyTLS {
		// Start from a copy of http.DefaultTransport so proxy settings from the
		// environment, dial/TLS-handshake timeouts and connection pooling match
		// the verifying configuration; a bare &http.Transport{} drops them all.
		var tr *http.Transport
		if dt, ok := http.DefaultTransport.(*http.Transport); ok {
			tr = dt.Clone()
		} else {
			tr = &http.Transport{}
		}
		// Deliberate and opt-in via WithSkipVerifyTLS.
		tr.TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
		httpClient.Transport = tr
	}
	c.c = httpClient
}

// doReq sends req with the shared http.Client and returns the response as is.
//
// Before sending, the "action" header (a label used only in the debug log) is
// removed from req and Referer is set to the server address. err is whatever
// http.Client.Do reports (including a refused redirect); resp is returned even
// when err is non-nil so callers can inspect it. A panic while sending is
// converted into an error.
func (c *Client) doReq(req *http.Request) (resp *http.Response, err error) {
	defer func() {
		if p := recover(); p != nil {
			err = fmt.Errorf("doReq panic: %v", p)
		}
	}()

	// "action" is a debug label set by callers; it must not reach the server.
	action := req.Header.Get("action")
	if action != "" {
		req.Header.Del("action")
	}

	debug := c.l != nil
	if debug {
		c.l.Printf("-- request start --")
		c.l.Printf("url = %s", req.URL)
		for _, ck := range c.c.Jar.Cookies(req.URL) {
			c.l.Printf("action = %s req cookie name = %s, value = %s",
				action, ck.Name, ck.Value)
		}
	}

	req.Header.Set("Referer", c.addr)
	resp, err = c.c.Do(req)

	if debug && resp != nil {
		c.l.Printf("status = %v", resp.StatusCode)
		if resp.StatusCode == http.StatusFound {
			loc, _ := resp.Location()
			c.l.Printf("location = %v", loc)
		}
	}
	if err != nil {
		return resp, err
	}

	if debug {
		for _, ck := range resp.Cookies() {
			c.l.Printf("action = %s, resp cookie name = %s, value = %s",
				action, ck.Name, ck.Value)
		}
		c.l.Printf("-- request end --")
	}
	return resp, nil
}

// login runs the full login handshake while holding c.lock for writing:
// preLogin, then the SCRAM first and second messages, then doScramFinal.
// A panic during the handshake is converted into an error.
func (c *Client) login(ctx context.Context) (err error) {
	defer func() {
		if p := recover(); p != nil {
			err = fmt.Errorf("login panic: %v", p)
		}
	}()
	// The write lock keeps the handshake from overlapping with requests that
	// doReqWithCheckLogin issues under the read lock.
	c.lock.Lock()
	defer c.lock.Unlock()

	if err = c.preLogin(ctx); err != nil {
		return err
	}
	sc := scram.NewClient(sha256.New, c.username, c.password)

	var serverFirst, serverFinal []byte
	if serverFirst, err = c.doScramFirstAction(ctx, sc); err != nil {
		return err
	}
	if c.l != nil {
		c.l.Printf("scOut first = %s", serverFirst)
	}
	if serverFinal, err = c.doScramSecondAction(ctx, sc, serverFirst); err != nil {
		return err
	}
	if c.l != nil {
		c.l.Printf("scOut second = %s\n", serverFinal)
	}
	return c.doScramFinal(ctx, sc, serverFinal)
}

// preLogin is the first step of login: it GETs /login with a niagara_userid
// cookie carrying the configured username. The response content is not used.
// Its body is drained and closed so the underlying connection is released
// instead of leaking on every login.
func (c *Client) preLogin(ctx context.Context) error {
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.addr+"/login", nil)
	if err != nil {
		return err
	}
	req.Header.Add("action", "preLogin")
	req.AddCookie(&http.Cookie{
		Name:  "niagara_userid",
		Value: c.username,
	})
	resp, err := c.doReq(req)
	if err != nil {
		return err
	}
	defer resp.Body.Close()
	// Best effort: a failed drain only prevents connection reuse.
	_, _ = io.Copy(ioutil.Discard, resp.Body)
	return nil
}

// doScramFirstAction starts the SCRAM exchange. It lets sc produce the
// client-first message, POSTs it to /j_security_check, and returns the raw
// response body. login passes that body on to doScramSecondAction.
func (c *Client) doScramFirstAction(ctx context.Context, sc *scram.Client) ([]byte, error) {
	sc.Step(nil)
	if stepErr := sc.Err(); stepErr != nil {
		return nil, stepErr
	}
	payload := strings.NewReader("action=sendClientFirstMessage&clientFirstMessage=" + string(sc.Out()))
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.addr+"/j_security_check", payload)
	if err != nil {
		return nil, err
	}
	req.Header.Set("Content-Type", "application/x-niagara-login-support")
	// Debug label only; doReq strips it before sending.
	req.Header.Set("action", "first")
	resp, err := c.doReq(req)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()
	return ioutil.ReadAll(resp.Body)
}

// doScramSecondAction feeds the server's first reply (in) to sc, POSTs the
// resulting client-final message to /j_security_check, and returns the raw
// response body. login passes that body on to doScramFinal.
func (c *Client) doScramSecondAction(ctx context.Context, sc *scram.Client, in []byte) ([]byte, error) {
	sc.Step(in)
	if stepErr := sc.Err(); stepErr != nil {
		return nil, stepErr
	}
	payload := strings.NewReader("action=sendClientFinalMessage&clientFinalMessage=" + string(sc.Out()))
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.addr+"/j_security_check", payload)
	if err != nil {
		return nil, err
	}
	req.Header.Set("Content-Type", "application/x-niagara-login-support")
	// Debug label only; doReq strips it before sending.
	req.Header.Set("action", "second")
	resp, err := c.doReq(req)
	if err != nil {
		return nil, err
	}
	defer resp.Body.Close()
	return ioutil.ReadAll(resp.Body)
}

// doScramFinal completes the login. It first feeds the server's final SCRAM
// reply (in) to sc and returns sc.Err() if that step fails. It then GETs
// /j_security_check.
//
// A refused redirect (ErrSkipRedirect) from that GET is tolerated; any other
// error is returned. Whatever response is received is drained and closed.
func (c *Client) doScramFinal(ctx context.Context, sc *scram.Client, in []byte) error {
	sc.Step(in)
	if err := sc.Err(); err != nil {
		return err
	}
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.addr+"/j_security_check", nil)
	if err != nil {
		return err
	}
	req.Header.Add("action", "j_security_check_get")
	resp, err := c.doReq(req)
	if err != nil && !errors.Is(err, ErrSkipRedirect) {
		return err
	}
	// Guard against a nil response before touching the body.
	if resp != nil {
		defer resp.Body.Close()
		// Drain so the connection can be reused. After a refused redirect
		// the body is already closed and this read simply fails; ignore it.
		_, _ = io.Copy(ioutil.Discard, resp.Body)
	}
	return nil
}

// checkLogin inspects the result of doReq and reports whether the server
// redirected the request to the login page.
//
// err is the error returned by doReq. A refused redirect (ErrSkipRedirect) is
// expected and swallowed; any other error is passed back as errOut.
//
// needLogin is true when resp is a 302 whose Location path is "/login". In
// that case the response body is drained and closed. A Location that cannot
// be parsed is returned as the error. ctx is not used.
func (c *Client) checkLogin(ctx context.Context, resp *http.Response, err error) (needLogin bool, errOut error) {
	defer func() {
		if p := recover(); p != nil {
			errOut = fmt.Errorf("checkLogin panic: %v", p)
		}
	}()

	// errors.Is(nil, ...) is false, so a nil err stays nil here.
	if !errors.Is(err, ErrSkipRedirect) {
		errOut = err
	}
	if resp == nil || resp.StatusCode != http.StatusFound {
		return false, errOut
	}
	loc, locErr := resp.Location()
	if locErr != nil {
		return false, locErr
	}
	if loc.Path != "/login" {
		return false, errOut
	}

	needLogin = true
	defer resp.Body.Close()
	io.Copy(ioutil.Discard, resp.Body)
	return needLogin, errOut
}

// doReqWithCheckLogin sends req and returns the complete response body.
//
// With the redirect policy installed by NewClient, redirects are not followed.
// If the server answers with a redirect to /login, the client logs in and
// sends the request once more. Both the first attempt and the retry run while
// holding c.lock for reading, so neither can interleave with login, which
// takes the write lock.
//
// The first attempt consumes the request body. Before the retry, the body is
// rewound with req.GetBody; resending the drained body would otherwise fail
// with a Content-Length mismatch or transmit an empty payload. Requests
// without GetBody (no body, or a body that cannot be replayed) are resent
// unchanged.
func (c *Client) doReqWithCheckLogin(ctx context.Context, req *http.Request) ([]byte, error) {
	c.lock.RLock()
	resp, err := c.doReq(req)
	needLogin, err := c.checkLogin(ctx, resp, err)
	c.lock.RUnlock()
	if err != nil {
		return nil, err
	}

	if needLogin {
		if err := c.login(ctx); err != nil {
			return nil, err
		}
		retry := req
		if req.GetBody != nil {
			body, err := req.GetBody()
			if err != nil {
				return nil, fmt.Errorf("rewind request body for retry: %w", err)
			}
			retry = req.Clone(req.Context())
			retry.Body = body
		}
		c.lock.RLock()
		resp, err = c.doReq(retry)
		c.lock.RUnlock()
		if err != nil {
			return nil, err
		}
	}

	if resp == nil {
		return nil, errors.New("no data received")
	}
	defer resp.Body.Close()
	return ioutil.ReadAll(resp.Body)
}

// getUri resolves s against the server address. If s already starts with
// c.addr it is returned unchanged. Otherwise c.addr (without trailing slashes)
// and s (without leading slashes) are joined with exactly one '/'.
func (c *Client) getUri(s string) string {
	if strings.HasPrefix(s, c.addr) {
		return s
	}
	base := strings.TrimRight(c.addr, "/")
	rel := strings.TrimLeft(s, "/")
	return base + "/" + rel
}

// Get fetches the object at path with an HTTP GET and decodes the XML reply.
//
// path is normally relative to the server address, i.e. without a host part
// such as "http://localhost:8000". A URL that already starts with the
// configured address is used as-is. An empty reply body yields (nil, nil).
// If the reply cannot be decoded, obj is nil and err describes the failure.
func (c *Client) Get(ctx context.Context, path string) (obj *types.Obj, err error) {
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, c.getUri(path), nil)
	if err != nil {
		return nil, err
	}
	b, err := c.doReqWithCheckLogin(ctx, req)
	if err != nil {
		return nil, err
	}
	if len(b) == 0 {
		return nil, nil
	}
	// Decode into a local value so a partially decoded object is never
	// returned alongside an error.
	out := new(types.Obj)
	if err := xml.Unmarshal(b, out); err != nil {
		return nil, err
	}
	return out, nil
}

// Post XML-encodes v, sends it to path with an HTTP POST, and decodes the XML
// reply.
//
// v may be nil, in which case the request carries no body. path is normally
// relative to the server address, i.e. without a host part such as
// "http://localhost:8000". A URL that already starts with the configured
// address is used as-is. An empty reply body yields (nil, nil). If the reply
// cannot be decoded, obj is nil and err describes the failure.
func (c *Client) Post(ctx context.Context, path string, v interface{}) (obj *types.Obj, err error) {
	var r io.Reader
	// A nil v means "send no body".
	if v != nil {
		b, err := xml.Marshal(v)
		if err != nil {
			return nil, err
		}
		if c.l != nil {
			c.l.Printf("data = %s", b)
		}
		r = bytes.NewReader(b)
	}
	req, err := http.NewRequestWithContext(ctx, http.MethodPost, c.getUri(path), r)
	if err != nil {
		return nil, err
	}
	req.Header.Set("Content-Type", "application/xml; charset=utf-8")
	b, err := c.doReqWithCheckLogin(ctx, req)
	if err != nil {
		return nil, err
	}
	if len(b) == 0 {
		return nil, nil
	}
	// Decode into a local value so a partially decoded object is never
	// returned alongside an error.
	out := new(types.Obj)
	if err := xml.Unmarshal(b, out); err != nil {
		return nil, err
	}
	return out, nil
}
